from typing import Any, Dict, Generator, List, Optional, Union

import torch as th
import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import BaseBuffer
from algos.arppo.type_aliases import CustomRolloutBufferSamples
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.running_mean_std import RunningMeanStd

import pdb

class CustomRolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.
    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.
    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
        variant: str = 'mark-algo3',
        normalize_rewards = False,
        action_mask_extractor = None
    ):
        """
        Build the buffer and allocate its storage.

        ``buffer_size`` through ``n_envs`` are described in the class docstring;
        the remaining parameters are specific to this buffer:

        :param variant: Name of the advantage / value-target formulation applied by
            ``compute_returns_and_advantage`` (``'mark-algo3'``, ``'zhang'`` or ``'discount'``)
        :param normalize_rewards: When true, ``compute_returns_and_advantage`` divides the
            stored rewards by the square root of the running reward variance and clips
            them to [-20, 20]
        :param action_mask_extractor: Kept on the instance as ``self.action_mask_extractor``
        """
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # Estimator settings
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.variant = variant
        self.normalize_rewards = normalize_rewards
        self.action_mask_extractor = action_mask_extractor

        # Running reward statistics: created once here, so ``reset()`` never clears them
        self.rew_rms = RunningMeanStd(shape=())

        # Placeholders only; the real arrays are allocated by ``reset()`` below
        self.observations = None
        self.actions = None
        self.rewards = None
        self.advantages = None
        self.returns = None
        self.value_targets = None
        self.episode_starts = None
        self.values = None
        self.log_probs = None
        self.generator_ready = False

        self.reset()

    def reset(self) -> None:
        """
        Re-allocate every storage array as zero-filled ``float32`` and mark the
        sample generator as not prepared, then hand over to the base-class ``reset``.

        ``self.rew_rms`` is deliberately left untouched.
        """
        steps_by_envs = (self.buffer_size, self.n_envs)

        # Arrays carrying extra trailing dimensions
        self.observations = np.zeros(steps_by_envs + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros(steps_by_envs + (self.action_dim,), dtype=np.float32)

        # One scalar per (step, env)
        scalar_fields = (
            "rewards",
            "returns",
            "episode_starts",
            "values",
            "log_probs",
            "advantages",
            "value_targets",
            "valid_obs",
        )
        for field_name in scalar_fields:
            setattr(self, field_name, np.zeros(steps_by_envs, dtype=np.float32))

        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: fill ``self.advantages`` and ``self.value_targets``
        for every step of the buffer, walking it backwards.

        What is computed depends on ``self.variant``. Below, ``r`` and ``V`` are the
        stored reward and value of a step, ``V'`` is the value of the following step
        (``last_values`` for the final step), ``nnt`` is ``1 - episode_start`` of the
        following step (``1 - dones`` for the final step) and ``mean_rew`` is the
        per-env mean of the stored rewards:

        - ``'mark-algo3'``: ``delta = r - mean_rew + gamma * V' * nnt - V`` is accumulated
          backwards with factor ``gamma * gae_lambda * nnt``; the value target is that
          accumulation plus ``V``. The advantage is the single-step
          ``r - mean_rew + V' * nnt - V`` (no ``gamma``).
        - ``'zhang'``: the value target is ``r - mean_rew + V' * nnt`` and the advantage
          is that target minus ``V``.
        - ``'discount'``: GAE(lambda) advantage (https://arxiv.org/abs/1506.02438) built from
          ``delta = r + gamma * V' * nnt - V``; the value target is the TD(lambda) estimate
          ``advantage + V`` (see https://github.com/DLR-RM/stable-baselines3/pull/375).

        When ``self.normalize_rewards`` is set, ``self.rewards`` is first overwritten with
        the rewards divided by the square root of the running reward variance and clipped
        to [-20, 20]. ``self.returns`` is not written by this method.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        :raises ValueError: if ``self.variant`` is not one of the supported variants
        """
        supported_variants = ('mark-algo3', 'zhang', 'discount')
        # Validate before touching any state: an unknown variant would otherwise fall
        # through every branch below and silently leave advantages/targets at zero.
        if self.variant not in supported_variants:
            raise ValueError(
                f"Unknown variant {self.variant!r}; expected one of {supported_variants}"
            )

        last_values = last_values.clone().cpu().numpy().flatten()

        # per env
        # TODO verify when n_env > 1
        if self.normalize_rewards:
            scaled_rewards = self.rewards / np.sqrt(self.rew_rms.var + 1e-8)
            # Cast back so the buffer keeps its float32 dtype even when the float64
            # running variance promotes the division result.
            self.rewards = np.clip(scaled_rewards, -20, 20).astype(np.float32, copy=False)

        mean_rew = np.mean(self.rewards, axis=0)
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                # Final step: bootstrap from the values supplied by the caller
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]

            if self.variant == 'mark-algo3':
                # corresponds to (4.24) in paper, not original (4.21) due to mean reward
                delta = self.rewards[step] - mean_rew + self.gamma * next_values * next_non_terminal - self.values[step]
                last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
                # mark's calculation of advantage is just the single step delta instead of an accumulation
                # the accumulation using lambda is implicitly done in the critic computation
                self.value_targets[step] = last_gae_lam + self.values[step]

                # same as delta without gamma
                self.advantages[step] = self.rewards[step] - mean_rew + next_values * next_non_terminal - self.values[step]
            elif self.variant == 'zhang':
                # average reward
                sub_diff = self.rewards[step] - mean_rew
                self.value_targets[step] = sub_diff + next_values * next_non_terminal
                self.advantages[step] = sub_diff + next_values * next_non_terminal - self.values[step]
            else:
                # 'discount': standard GAE
                delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
                last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
                self.advantages[step] = last_gae_lam

        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        if self.variant == 'discount':
            self.value_targets = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor
    ) -> None:
        """
        Store one transition in the row at ``self.pos`` and advance ``self.pos``.

        The running reward statistics (``self.rew_rms``) are updated with ``reward``
        before the transition is written.

        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        # NOTE: every transition is stored and flagged in ``valid_obs``;
        # ``self.action_mask_extractor`` is not consulted here.

        # Give a 0-d log-prob tensor an explicit shape to avoid an error
        if log_prob.ndim == 0:
            log_prob = log_prob.reshape(-1, 1)

        # With several envs and discrete observations numpy cannot broadcast
        # (n_discrete,) to (n_discrete, 1), hence the explicit reshape
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        # Actions get the same treatment
        action = action.reshape((self.n_envs, self.action_dim))

        self.rew_rms.update(reward)

        row = self.pos
        self.observations[row] = np.array(obs, copy=True)
        self.actions[row] = np.array(action, copy=True)
        self.rewards[row] = np.array(reward, copy=True)
        self.episode_starts[row] = np.array(episode_start, copy=True)
        self.values[row] = value.clone().cpu().numpy().flatten()
        self.log_probs[row] = log_prob.clone().cpu().numpy()
        self.valid_obs[row] = 1

        self.pos = row + 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[CustomRolloutBufferSamples, None, None]:
        """
        Yield the stored transitions in random order, as minibatches.

        On the first call after a ``reset()`` the per-step arrays are flattened in place
        with ``self.swap_and_flatten``. Only the first ``self.pos`` steps of each
        environment are sampled, so a partially filled buffer may be used.

        :param batch_size: Number of samples per minibatch; when ``None`` a single
            batch holding every filled sample is yielded.
        :return: Generator of ``CustomRolloutBufferSamples``
        """
        # Prepare the data
        if not self.generator_ready:
            flattened_fields = (
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "value_targets",
                "rewards",
            )
            for field_name in flattened_fields:
                setattr(self, field_name, self.swap_and_flatten(getattr(self, field_name)))
            self.generator_ready = True

        # stable-baselines3's ``swap_and_flatten`` swaps the step and env axes before
        # flattening, so a transition lives at flat index ``env * buffer_size + step``.
        # Build exactly the indices of the filled steps; for a full buffer (or a single
        # env) this is simply ``arange(pos * n_envs)``.
        # NOTE(review): relies on that env-major layout - confirm if swap_and_flatten is overridden.
        filled_steps = np.arange(self.pos)
        env_offsets = np.arange(self.n_envs) * self.buffer_size
        filled_rows = (env_offsets[:, None] + filled_steps[None, :]).reshape(-1)
        indices = np.random.permutation(filled_rows)
        n_samples = len(indices)

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_samples

        start_idx = 0
        while start_idx < n_samples:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> CustomRolloutBufferSamples:
        """
        Gather the rows selected by ``batch_inds`` and wrap them in a sample tuple.

        :param batch_inds: Indices into the stored arrays
        :param env: Not used by this method
        :return: ``CustomRolloutBufferSamples`` whose fields, each passed through
            ``self.to_torch``, are in the order observations, actions, values,
            log_probs, advantages, returns, value_targets, rewards
        """
        # Observations and actions keep their trailing dimensions
        picked = [self.observations[batch_inds], self.actions[batch_inds]]

        # Every remaining field is flattened to one dimension
        flat_sources = (
            self.values,
            self.log_probs,
            self.advantages,
            self.returns,
            self.value_targets,
            self.rewards,
        )
        picked.extend(source[batch_inds].flatten() for source in flat_sources)

        return CustomRolloutBufferSamples(*[self.to_torch(item) for item in picked])